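%% analyzePatternsWM - usage
%
%   [rs, sig_cells] = analyzePatternsWM(area_name, type_name, st1_msec, del1_msec, ...
%       st2_msec, del2_msec, start_time_msecs, end_time_msecs, significance_level, save_seconds)
%
% MATLAB must be able to find patternhist2.dat, the spikes<msec>.dat files and
% the read_groups script (which is expected to read groups.dat).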

function [rs sig_cells] = analyzePatternsWM( area_name, type_name, st1_msec, del1_msec, st2_msec, del2_msec, start_time_msecs, end_time_msecs,significance_level, save_seconds)
% ANALYZEPATTERNSWM  Per-phase pattern selectivity analysis of one neural group.
%
% The trial is split into six analysis phases (ST1, DEL1, ST2 SAME, ST2 DIFF,
% DEL2 SAME, DEL2 DIFF). For every phase the function
%   1. collects, per pattern, the spike count of each neuron of the requested
%      group inside the phase window of every selected trial,
%   2. plots mean responses, a grand tuning curve, an information histogram
%      and the best pattern per cell,
%   3. trains/tests a linear classifier on the spike-count vectors,
%   4. runs pairwise rank-sum tests to find cells that respond significantly
%      more to one pattern than to every other pattern.
%
% Inputs
%   area_name, type_name   strings matched (after deblank) against the
%                          areas{} / celltype{} entries provided by read_groups.
%   st1_msec, del1_msec,
%   st2_msec, del2_msec    durations (msec) used to build the phase windows.
%                          Each window is shrunk by 50 msec at both ends.
%   start_time_msecs,
%   end_time_msecs         only trials whose adjusted time
%                          (patternhist2(:,1) - st1_msec - del1_msec) lies in
%                          [start_time_msecs, end_time_msecs) are analysed.
%   significance_level     (optional, default 0.05) alpha passed to ranksum.
%   save_seconds           (optional, default 1) seconds of data per
%                          spikes<msec>.dat file.
%
% Outputs (both reflect the LAST phase processed, since every phase overwrites them)
%   rs         N x P x P logical; rs(n,i,j) is true when cell n responds
%              significantly more to pattern i than to pattern j.
%   sig_cells  1 x P struct array; sig_cells(i).list holds the 0-based neuron
%              ids of the cells that prefer pattern i over ALL other patterns.
%
% Files / externals used: patternhist2.dat, spikes<msec>.dat, the read_groups
% script, bitsinfo, and toolbox functions (ranksum, confusionmat, newlind,
% sim, ind2vec, vec2ind).

%global samples cats
global mfr rows cols
msecs_per_sec = 1000;
edge_margin_msec = 50;   % trimmed from both ends of every phase window

% Optional arguments are defaulted independently. The original nested test
% left save_seconds undefined when exactly nine arguments were supplied.
if nargin < 9
    significance_level = 0.05;
end
if nargin < 10
    save_seconds = 1;
end

% Give the outputs a value up front so that the early return below (group
% not found) does not raise "output argument not assigned".
rs = [];
sig_cells = struct('list', {});

% Column 1 is used as a time in msec; columns 2 and 3 are used as 0-based
% pattern indices (the "current" pattern and the pattern it is compared to).
patternhist2 = load('patternhist2.dat');
num_patterns = max( patternhist2(:,2) ) + 1;

% Phase windows, relative to the adjusted trial time. Phases 3/4 and 5/6
% share the same window and differ only in which trials they accept.
num_phases = 6;
st1_end  = st1_msec;
del1_end = st1_end + del1_msec;
st2_end  = del1_end + st2_msec;
del2_end = st2_end + del2_msec;
phase_start_msec = [0,       st1_end,  del1_end, del1_end, st2_end,  st2_end ] + edge_margin_msec;
phase_end_msec   = [st1_end, del1_end, st2_end,  st2_end,  del2_end, del2_end] - edge_margin_msec;
phase_names   = {'ST1', 'DEL1', 'ST2 SAME', 'ST2 DIFF', 'DEL2 SAME', 'DEL2 DIFF'};
wants_same    = [false, false, true,  false, true,  false];
wants_diff    = [false, false, false, true,  false, true ];

for p = 1 : num_phases
    disp('>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> NEW PHASE <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<');

    mfr=[];
    prefix = phase_names{p};

    % areas, celltype, neuron_id, la and lb used below are not assigned in
    % this function, so they are expected to be created by this script.
    read_groups; % Assumes groups.dat has the neural area data

    % Locate the requested area / cell-type group.
    index = [];
    for i = 1:length(areas)
        if strcmp( deblank(areas{i}), area_name) && strcmp( type_name, deblank(celltype{i}))
            index = i
            break;
        end
    end
    if isempty(index)
        disp([area_name ' ' type_name ' is not in groups.dat.']);
        return;
    end

    % neuron_id holds 0-based ids (echoed as such); convert to 1-based rows.
    first = neuron_id(index,1)
    last  = neuron_id(index,2)

    num_neurons = max(neuron_id(:,2))+1

    first = first + 1;
    last = last + 1;
    rows = la(index);
    cols = lb(index);
    N = rows*cols

    count = zeros( 1, num_patterns );   % trials accumulated per pattern

    t_adjusted = patternhist2(:,1) - st1_msec - del1_msec;
    trial_indices_to_analyze = find(t_adjusted>=start_time_msecs & t_adjusted<end_time_msecs);

    for i = trial_indices_to_analyze'

        t = t_adjusted(i);

        current_pattern = patternhist2( i, 2 ) + 1;
        offset_pattern = patternhist2( i, 3 ) + 1;

        % SAME phases keep only matching trials, DIFF phases only
        % non-matching ones, ST1/DEL1 keep every trial.
        if wants_same(p)
            include_trial = ( current_pattern == offset_pattern );
        elseif wants_diff(p)
            include_trial = ( current_pattern ~= offset_pattern );
        else
            include_trial = true;
        end

        if include_trial
            win_first = t + phase_start_msec(p);
            win_last  = t + phase_end_msec(p) - 1;

            % find the file(s) we need to open
            start_file = ceil( win_first/(save_seconds*msecs_per_sec) );
            stop_file  = ceil( win_last /(save_seconds*msecs_per_sec) );

            sdat = [];
            for file = 0 : (stop_file-start_file)
                filename = ['spikes' num2str( (start_file + file)*save_seconds*1000, 15 ) '.dat']
                sdat_temp = load(filename);
                sdat = [sdat; sdat_temp];
            end

            % stop_file*save_seconds*1000 >= win_last by construction, so the
            % window below always fits even when no late spikes were recorded.
            max_time_msec = stop_file*save_seconds*msecs_per_sec
            if isempty(sdat)
                % No spikes at all in the loaded files: all-zero raster.
                S = sparse( num_neurons, max_time_msec );
            else
                % Rows are neurons (column 2 of the file, 0-based), columns
                % are times (column 1 of the file, stored at time+1).
                S = sparse( sdat(:,2)+1, sdat(:,1)+1, 1, ...
                            max(num_neurons,   max(sdat(:,2))+1), ...
                            max(max_time_msec, max(sdat(:,1))+1) );
            end

            count(current_pattern) = count(current_pattern) + 1;
            c = count(current_pattern);
            % NOTE(review): spikes are stored at column time+1 but the window
            % is indexed with win_first:win_last directly (see the commented
            % "+ 1" lines in the original) - confirm the intended alignment.
            mfr(current_pattern).spikes(:, c ) = full(sum( S( first:last, win_first:win_last ), 2));
        end

        % Gather mean rates for stats.
        disp('****************************')
    end

    count

    k=0;
    figure(p*100+k); k= k + 1;

    % --- Mean response map of the group for every pattern -----------------
    disp('Mean firing rate of cells to each pattern');
    meanresponses=zeros(rows,cols,num_patterns);
    side = ceil(sqrt(num_patterns));
    window_msec = phase_end_msec(p)-phase_start_msec(p);
    for i = 1:num_patterns
        subplot(side,side,i);

        % spike count per window -> spikes/sec, laid out on the rows x cols sheet
        meanresponses(:,:,i)=rot90(reshape( mean(mfr(i).spikes,2)*1000/window_msec, rows, cols )); % ???
        imagesc(meanresponses(:,:,i));
        title([prefix ': ' int2str(i)],'FontSize',10);
        colorbar;
        set(gca,'XTickLabel',{});
    end

    figure(p*100+k); k= k + 1;

    imagesc(mean(meanresponses,3));
    colorbar
    title([prefix ': ' 'Mean firing rate for all neurons across all patterns'],'FontSize',10);
    drawnow

    % --- Grand tuning curve: per-cell responses sorted best -> worst -------
    figure(p*100+k); k= k + 1;
    sortedmfr=sort(meanresponses,3,'descend');
    sortedmfr=reshape(sortedmfr,rows*cols,num_patterns);
    errorbar(mean(sortedmfr),std(sortedmfr));
    title([prefix ': ' 'Grand tuning curve over all neurons (mfr +/- std.)'],'FontSize',10);
    ylabel('Mean firing rate (sp/sec)','FontSize',10);
    xlabel('pattern number sorted from best to worst','FontSize',10);
    drawnow

    figure(p*100+k); k= k + 1;
    imagesc(sortedmfr);

    figure(p*100+k); k= k + 1;
    % bitsinfo is an external helper; presumably returns bits per neuron - TODO confirm
    b=bitsinfo(sortedmfr,2.5:5:(max(sortedmfr(:,1))+2.49)); % bins of ~ 5 sp/sec.
    hist(b,20);
    title([prefix ': ' 'histogram of bits of information conveyed by neurons'],'FontSize',10);
    drawnow

    % --- Preferred pattern of every cell ----------------------------------
    figure(p*100+k);k=k+1;
    [best_response best_indices]=max(meanresponses,[],3)
    imagesc(best_indices, [0 num_patterns]); % start scale at 0 for comparison with significance plot
    colorbar
    axis off
    title([prefix ': ' 'Optimal pattern for each cell'],'FontSize',10);
    drawnow

    figure(p*100+k);k=k+1;
    hist(best_indices(:),num_patterns);
    title([prefix ': ' 'Histogram of neurons responding best to each pattern'],'FontSize',10);
    drawnow

    % See if we can classify the samples given the network response vector
    train_classifier = true;
    if train_classifier

        samples = [];
        cats=[];
        testing_samples = [];
        testing_cats=[];

        % First 2/3 of each pattern's trials train the classifier, the
        % remainder is held out for testing.
        for i = 1:num_patterns

            training_patterns = floor(2/3*count(i));
            testing_patterns = count(i)-training_patterns;

            samples =[samples; mfr( i ).spikes(:,1:training_patterns)'];
            cats = [cats ones(1,training_patterns)*i];

            testing_samples = [testing_samples; mfr(i).spikes(:,training_patterns+1:end)'];
            testing_cats = [testing_cats ones(1,testing_patterns)*i];
        end

        % Try a backprop classifier
        T = full(ind2vec(cats)) ; % turn category indices into appropriate outputs.

        bpnet=newlind(samples',T);
        Y=sim(bpnet,samples');

        [m Yc] = max(Y);

        percent_correct_bp = sum(Yc==(vec2ind(T)))/length(Yc)*100
        cMat1 = confusionmat(cats,Yc) % the confusion matrix

        % now on test set.
        T = full(ind2vec(testing_cats)) ; % turn category indices into appropriate outputs.
        Y=sim(bpnet,testing_samples');

        [m Yc] = max(Y);

        percent_correct_bp = sum(Yc==(vec2ind(T)))/length(Yc)*100
        cMat1 = confusionmat(testing_cats,Yc) % the confusion matrix

    end

    % Show significantly preferential responses to each pattern.
    % Reset per phase: pairs skipped below (fewer than 2 trials) must read
    % as "not significant" rather than keep the previous phase's values.
    rs = false(N, num_patterns, num_patterns);
    num_significant = zeros(num_patterns, num_patterns);

    for i = 1:num_patterns
        for j = 1:num_patterns
            if ( count(j) > 1 && count(i) > 1)
                for n = 1:N
                    [junk significant] = ranksum( mfr( i ).spikes(n, : ) , mfr( j ).spikes(n, : ) ,significance_level);
                    % significant AND stronger response to i than to j
                    rs(n,i,j) = significant && mean(mfr( i ).spikes(n, : ))  > mean(mfr( j ).spikes(n, : ));
                end
                num_significant(i,j) = sum(rs(:,i,j));
            end
        end
        disp('.');
    end

    disp('number of cells which distinguish pattern "row" from pattern "col"');
    num_significant

    discriminators = zeros(1,num_patterns);
    merged_discriminator_map = zeros(N,1);

    % Shared tail of the three figure titles below.
    stats_suffix = ['statistically (p< ' num2str(significance_level) ', Wilcoxon RS); t=' num2str(phase_start_msec(p)) ' ' num2str(phase_end_msec(p)) ' and respond to maximally to it.' ];

    for i = 1:num_patterns
        % A cell discriminates pattern i when it beats every other pattern.
        temp = ones(N,1);
        for j = 1:num_patterns
            if i ~= j
                temp = temp & rs(:,i,j);
            end
        end

        merged_discriminator_map = merged_discriminator_map + i*temp;
        discriminators(i) = sum(temp);

        figure(p*100+k); k= k + 1;

        imagesc(rot90(reshape( temp, rows, cols )))
        colormap(gray);
        title({[prefix ': ' 'Cells in area ' area_name type_name ' that distinguish pattern ' num2str(i)],[ ' from ALL others ' stats_suffix ]},'FontSize',10);
        size(temp)
        disp('Significant cell numbers (start counting at 0 as in C): ');
        % find() is 1-based within the group and first is 1-based, hence -2.
        sig_cells(i).list = (find(temp)+first-2)';
    end

    figure(p*100+k);k=k+1;
    bar(discriminators);
    title({[prefix ': ' 'Histogram of cells in area ' area_name type_name ' that distinguish one pattern from ALL others'],stats_suffix},'FontSize',10);

    figure(p*100+k);k=k+1;
    imagesc(rot90(reshape( merged_discriminator_map, rows, cols )))
    title({[prefix ': ' 'Cells in area ' area_name type_name ' that distinguish one pattern from ALL others'],stats_suffix},'FontSize',10);
    axis off
    colorbar

    % Dump the raw per-trial spike counts of every discriminating cell.
    merged_discriminator_map
    for i = find(merged_discriminator_map')
       disp(['Responses for unit ' num2str(i)]);
       for j = 1:num_patterns
          disp([j mfr(j).spikes(i,:)]);
       end
    end
end

